#!/usr/bin/python
# -*- coding: utf-8 -*-

import warnings,numpy,argparse,re,sys,os,os.path,unicodedata,multiprocessing,codecs
from collections import Counter
import ocrolib
from pylab import *
from scipy.ndimage import filters

# Silence numpy's RankWarning (the original note says it comes from polyfit).
# NOTE(review): nothing in this file calls polyfit directly; presumably the
# warning is triggered inside an imported library -- confirm before removing.
warnings.simplefilter('ignore',numpy.RankWarning) 

# Command-line interface.  The kinds listed here are exactly the ones that
# normalize() knows how to handle; validating them up front turns a typo into
# an immediate usage error instead of an exception raised later per file.
KINDS = ["exact","nospace","spletdig","letdig","letters","digits","lnc"]

parser = argparse.ArgumentParser(description = """
Compute the edit distances between ground truth and recognizer output.
Run with the ground truth files as arguments, and it will find the
corresponding recognizer output files using the given extension (-x).
Missing output files are handled as empty strings, unless the -s
option is given.
""")
parser.add_argument("files",default=[],nargs='*',help="input lines")
parser.add_argument("-x","--extension",default=".txt",help="extension for recognizer output, default: %(default)s")

parser.add_argument("-k","--kind",default="exact",choices=KINDS,help="kind of comparison (exact, nospace, spletdig, letdig, letters, digits, lnc), default: %(default)s")
# NOTE(review): args.skipmissing is parsed here but is not read anywhere else
# in the file as visible; missing outputs are only reported separately via the
# "missing"/"errnomiss" summary lines -- confirm whether -s should do more.
parser.add_argument("-s","--skipmissing",action="store_true",help="don't use missing or empty output files in the calculation")

parser.add_argument("-c","--confusion",default=0,type=int,help="output this many top confusion")
parser.add_argument("-a","--allconf",default=None,help="output all confusions to this file")
parser.add_argument("-e","--perfile",default=None,help="output per-file errors to this file")
parser.add_argument("-C","--context",type=int,default=0,help="context for confusion matrix")
parser.add_argument("-Q","--parallel",type=int,default=multiprocessing.cpu_count(),help="number of parallel worker processes, default: %(default)s")
args = parser.parse_args()
# expand the file arguments through ocrolib's globbing helper
args.files = ocrolib.glob_all(args.files)

def levenshtein(a,b):
    """Return the Levenshtein (edit) distance between sequences a and b.

    Insertions, deletions and substitutions each cost 1.  Only one row of
    the dynamic-programming table is kept at a time, and that row runs
    along the shorter of the two inputs.
    """
    # keep the shorter sequence along the row so the rows stay small
    if len(a) > len(b):
        a,b = b,a
    width = len(a)
    row = list(range(width+1))
    for i,cb in enumerate(b,1):
        above = row
        row = [i]
        for j,ca in enumerate(a,1):
            substitute = above[j-1]
            if ca!=cb:
                substitute = substitute+1
            row.append(min(above[j]+1, row[j-1]+1, substitute))
    return row[width]

def xlevenshtein(a,b,context=1):
    """Compute the edit distance between a and b and list their differences.

    The full dynamic-programming table is built together with back-pointers,
    the optimal path is traced back to produce two aligned strings (with "_"
    marking a gap), and every run of differing positions -- widened by
    `context` aligned characters on each side -- is reported as a confusion.

    Arguments:
        a, b: the two strings to compare.  They must not contain "~", which
            is used internally as a separator (checked with an assert); "_"
            is used as the gap marker in the output.
        context: number of neighbouring aligned characters to include on
            each side of a difference (0 means only the differing characters).

    Returns:
        (cost, confusions): cost is the edit distance (a numpy float taken
        from the distance table); confusions is a list of (a_part, b_part)
        string pairs, in left-to-right order.
    """
    n, m = len(a), len(b)
    # sources[i,j] holds the predecessor cell of (i,j) on an optimal path;
    # only the origin (0,0) has no predecessor.
    sources = empty((m+1,n+1),object)
    sources[:,:] = None
    dists = empty((m+1,n+1))
    dists[0,:] = arange(n+1)
    dists[:,0] = arange(m+1)
    # The first row and column can only be reached by moving along them;
    # without these back-pointers the traceback would stop as soon as it
    # touched the border and lose leading insertions/deletions.
    for j in range(1,n+1):
        sources[0,j] = (0,j-1)
    for i in range(1,m+1):
        sources[i,0] = (i-1,0)
    for i in range(1,m+1):
        previous = dists[i-1,:]
        current = dists[i,:]
        for j in range(1,n+1):
            # ties are resolved in the order: up, left, diagonal
            best,origin = previous[j]+1,(i-1,j)
            if current[j-1]+1<best:
                best,origin = current[j-1]+1,(i,j-1)
            delta = 1*(a[j-1]!=b[i-1])
            if previous[j-1]+delta<best:
                best,origin = previous[j-1]+delta,(i-1,j-1)
            dists[i,j] = best
            sources[i,j] = origin
    # read the result from the table so that an empty b (m==0) works too
    cost = dists[m,n]

    # reconstruct the path, starting at the real end cell (m,n)
    path = [(m,n)]
    while sources[path[-1]] is not None:
        path.append(sources[path[-1]])

    # walk the path and emit one aligned column per step
    al,bl = [],[]
    for (i,j),(i0,j0) in zip(path,path[1:]):
        al.append(a[j0] if j0!=j else "_")
        bl.append(b[i0] if i0!=i else "_")
    if not al:
        # both inputs are empty: nothing to align and nothing to filter
        return cost,[]
    al = "".join(al[::-1])
    bl = "".join(bl[::-1])

    # now compute a splittable string with the differences
    assert len(al)==len(bl)
    # pad so that the context window can extend past either end
    al = " "*context+al+" "*context
    bl = " "*context+bl+" "*context
    assert "~" not in al and "~" not in bl
    same = array([u==v for u,v in zip(al,bl)],'i')
    # a position stays "same" only if its whole context window is the same
    same = filters.minimum_filter(same,1+2*context)
    als = "".join([u if not s else "~" for u,s in zip(al,same)])
    bls = "".join([v if not s else "~" for v,s in zip(bl,same)])
    # both strings share the same mask, so the pieces line up one to one
    ags = re.split(r'~+',als)
    bgs = re.split(r'~+',bls)
    confusions = [(u,v) for u,v in zip(ags,bgs) if u!="" or v!=""]
    return cost,confusions

# Quote clean-up table consumed by normalize(): each entry is a
# (pattern, replacement) pair that maps a quote-like character onto one or
# two plain ASCII apostrophes.  Per the original note, this is the same list
# of replacements that is used for alignment.
# NOTE(review): normalize() already rewrites '"' and '`' itself before it
# consults this table, so the first two entries never match there.

replacements = [
    (u'"',u"''"),     # typewriter double quote -> two apostrophes
    (u"`",u"'"),      # grave accent -> apostrophe
    (u"\u00b4",u"'"),   # acute accent -> apostrophe
    (u"\u2018",u"'"), # left single quotation mark -> apostrophe
    (u"\u2019",u"'"), # right single quotation mark -> apostrophe
    (u"\u201c",u"''"), # left double quotation mark -> two apostrophes
    (u"\u201d",u"''"), # right double quotation mark -> two apostrophes
]

def normalize(s):
    """Normalize a text line before comparison, according to args.kind.

    The string is first cleaned up: "_", "~", "#" and the backtick become
    two apostrophes, the typewriter double quote becomes one apostrophe,
    curly double quotes become two apostrophes, whitespace is turned into
    plain spaces and trimmed at both ends, rows of four or more dots are
    collapsed to "....", the `replacements` table is applied, and the result
    is put into Unicode NFKD form.  Then characters are filtered:

        exact     keep everything
        nospace   drop whitespace
        spletdig  keep ASCII letters, digits and spaces
        letdig    keep ASCII letters and digits
        letters   keep ASCII letters
        digits    keep ASCII digits
        lnc       upper-case, then keep A-Z only

    Raises Exception for any other value of args.kind.
    """
    s = re.sub('[_~#]',"''",s)
    # NOTE(review): these two lines map '`' and '"' differently from the
    # `replacements` table (which is applied later and therefore never sees
    # them); left as is because changing it would alter reported error rates.
    s = re.sub('`',"''",s)
    s = re.sub('"',"'",s)
    # Spelled with escapes in a unicode literal: as a plain byte string in a
    # Python 2 source file the class would consist of the UTF-8 bytes of the
    # curly quotes and would match characters such as U+00E2 in unicode text.
    s = re.sub(u'[\u201c\u201d]',"''",s)
    # every whitespace character (newlines included) becomes a space, so no
    # separate newline removal is needed afterwards
    s = re.sub(r'\s',' ',s)
    s = re.sub(r'^\s+','',s)
    s = re.sub(r'\s+$','',s)
    s = re.sub(r'( *[.] *){4,}','....',s)
    for m,r in replacements:
        s = re.sub(m,r,s)
    s = unicodedata.normalize('NFKD',s)
    if args.kind=="exact":
        return s
    if args.kind=="nospace":
        return re.sub(r'\s','',s)
    if args.kind=="spletdig":
        return re.sub(r'[^A-Za-z0-9 ]','',s)
    if args.kind=="letdig":
        return re.sub(r'[^A-Za-z0-9]','',s)
    if args.kind=="letters":
        return re.sub(r'[^A-Za-z]','',s)
    if args.kind=="digits":
        return re.sub(r'[^0-9]','',s)
    if args.kind=="lnc":
        s = s.upper()
        return re.sub(r'[^A-Z]','',s)
    raise Exception("unknown normalization: "+args.kind)

# The arguments are expected to be ground-truth files; warn when the first one
# does not carry a ".gt." marker in its name.  The emptiness check keeps an
# empty file list from raising IndexError here.
if args.files and ".gt." not in args.files[0]:
    sys.stderr.write("warning: compare on .gt.txt files, not .txt files\n")


def process1(fname):
    """Score one ground-truth file against its recognizer output.

    The recognizer output is looked up next to `fname`, using the base name
    from ocrolib.allsplitext plus args.extension; if that file does not
    exist it is treated as an empty string.  Both texts go through
    normalize().  Confusions are only collected when the -c or -a options
    ask for them.

    Returns a tuple (fname, errors, ground truth length, missing, counts),
    where `missing` is the ground-truth length if the output file is absent
    (0 otherwise) and `counts` is a Counter of (output, truth) confusions.
    """
    reference = normalize(ocrolib.read_text(fname))
    outname = ocrolib.allsplitext(fname)[0]+args.extension
    if os.path.exists(outname):
        hypothesis = normalize(ocrolib.read_text(outname))
        missing = 0
    else:
        hypothesis = ""
        missing = len(reference)
    want_confusions = args.confusion>0 or args.allconf is not None
    if not want_confusions:
        # plain distance only; no confusions to report
        err = levenshtein(hypothesis,reference)
        return fname,err,len(reference),missing,Counter()
    err,pairs = xlevenshtein(hypothesis,reference,context=args.context)
    return fname,err,len(reference),missing,Counter(pairs)

# Score every input file; each result is the tuple returned by process1.
outputs = ocrolib.parallel_map(process1,args.files,parallel=args.parallel,chunksize=10)

errs = 0
total = 0
missing = 0
counts = Counter()

# Optional report files.  They are opened inside the try block so that both
# are closed again even if opening the second one, or writing, fails.
perfile = None
allconf = None
try:
    if args.perfile is not None:
        perfile = codecs.open(args.perfile,"w","utf-8")
    if args.allconf is not None:
        allconf = codecs.open(args.allconf,"w","utf-8")
    # sort on the file name only, so the remaining fields never get compared
    for fname,e,t,m,c in sorted(outputs,key=lambda result: result[0]):
        errs += e
        total += t
        missing += m
        # in-place update; avoids rebuilding the whole Counter for every file
        counts.update(c)
        if perfile is not None:
            perfile.write("%6d\t%6d\t%s\n"%(e,t,fname))
        if allconf is not None:
            # one line per distinct confusion in this file (at most 1000)
            for (a,b),v in c.most_common(1000):
                allconf.write("%s\t%s\t%s\n"%(a,b,fname))
finally:
    if perfile is not None: perfile.close()
    if allconf is not None: allconf.close()

# summary goes to stderr; stdout carries the confusions and the final rate
sys.stderr.write("errors    %8d\n"%errs)
sys.stderr.write("missing   %8d\n"%missing)
sys.stderr.write("total     %8d\n"%total)
if total==0:
    # no files, or only empty ground truth: the rates below are undefined
    sys.stderr.write("error: no ground truth characters found; cannot compute an error rate\n")
    sys.exit(1)
sys.stderr.write("err       %8.3f %%\n"%(errs*100.0/total))
sys.stderr.write("errnomiss %8.3f %%\n"%((errs-missing)*100.0/total))

# single-argument print(...) behaves the same under Python 2 and Python 3
if args.confusion>0:
    for (a,b),v in counts.most_common(args.confusion):
        print("%d\t%s\t%s"%(v,a,b))

print(errs*1.0/total)